package stream_test

import (
	"context"
	"testing"

	"github.com/Khan/genqlient/graphql"
	"github.com/stretchr/testify/assert"
	"github.com/wandb/wandb/core/internal/featurechecker"
	"github.com/wandb/wandb/core/internal/filestream"
	"github.com/wandb/wandb/core/internal/filetransfer"
	"github.com/wandb/wandb/core/internal/gqlmock"
	"github.com/wandb/wandb/core/internal/mailbox"
	"github.com/wandb/wandb/core/internal/observability"
	"github.com/wandb/wandb/core/internal/observabilitytest"
	"github.com/wandb/wandb/core/internal/runfiles"
	"github.com/wandb/wandb/core/internal/runhandle"
	"github.com/wandb/wandb/core/internal/runworktest"
	wbsettings "github.com/wandb/wandb/core/internal/settings"
	"github.com/wandb/wandb/core/internal/stream"
	"github.com/wandb/wandb/core/internal/watchertest"
	"github.com/wandb/wandb/core/pkg/artifacts"
	spb "github.com/wandb/wandb/core/pkg/service_go_proto"
	"go.uber.org/mock/gomock"
	"google.golang.org/protobuf/types/known/wrapperspb"
)

// validLinkArtifactResponse is the canned response body stubbed for the
// "LinkArtifact" GraphQL operation in this file's tests.
const validLinkArtifactResponse = `{
	"linkArtifact": { "versionIndex": 0 }
}`

// testFixtures bundles the Sender built by makeSender together with the
// dependencies it was constructed from, so tests can inspect or reuse them.
type testFixtures struct {
	// Sender is the sender under test.
	Sender    *stream.Sender
	// RunHandle is the run handle that was passed to the SenderFactory.
	RunHandle *runhandle.RunHandle
	// Settings are the settings the sender and its dependencies were built with.
	Settings  *wbsettings.Settings
	// Logger is the test logger shared by the sender and its dependencies.
	Logger    *observability.CoreLogger
}

// makeSender builds a stream.Sender for tests. It is wired to the given
// GraphQL client and to test doubles: runworktest run work, a test logger and
// a fake file watcher.
//
// Settings are fixed to run ID "run1", console "off" and a dummy API key.
// The returned fixtures also expose the run handle, settings and logger that
// were handed to the SenderFactory.
func makeSender(t *testing.T, client graphql.Client) testFixtures {
	t.Helper()

	work := runworktest.New()
	testLogger := observabilitytest.NewTestLogger(t)
	cfg := wbsettings.From(&spb.Settings{
		RunId:   wrapperspb.String("run1"),
		Console: wrapperspb.String("off"),
		ApiKey:  wrapperspb.String("test-api-key"),
	})

	backend := stream.NewBackend(testLogger, cfg)
	streamFactory := &filestream.FileStreamFactory{
		Logger:   testLogger,
		Printer:  observability.NewPrinter(),
		Settings: cfg,
	}
	transfers := stream.NewFileTransferManager(
		filetransfer.NewFileTransferStats(),
		testLogger,
		cfg,
	)
	uploaderFactory := &runfiles.UploaderFactory{
		FileTransfer: transfers,
		FileWatcher:  watchertest.NewFakeWatcher(),
		GraphQL:      client,
		Logger:       testLogger,
		Settings:     cfg,
	}
	handle := runhandle.New()

	factory := stream.SenderFactory{
		Logger:                  testLogger,
		Settings:                cfg,
		Backend:                 backend,
		FileStreamFactory:       streamFactory,
		FileTransferManager:     transfers,
		RunfilesUploaderFactory: uploaderFactory,
		Mailbox:                 mailbox.New(),
		GraphqlClient:           client,
		FeatureProvider:         featurechecker.NewServerFeaturesCache(nil, testLogger),
		RunHandle:               handle,
	}
	sender := factory.New(work)

	return testFixtures{
		Sender:    sender,
		RunHandle: handle,
		Settings:  cfg,
		Logger:    testLogger,
	}
}

// TestSendLinkArtifact checks that a LinkArtifact request is forwarded to the
// "LinkArtifact" GraphQL operation with its portfolio fields intact. When both
// IDs are present, only the server ID is sent (as artifactId).
//
// The steps share one mock and one sender, so GraphQL requests accumulate:
// the request produced by step i is found at index i.
func TestSendLinkArtifact(t *testing.T) {
	mockGQL := gqlmock.NewMockClient()
	x := makeSender(t, mockGQL)

	steps := []struct {
		clientID       string
		serverID       string
		control        *spb.Control
		wantClientID   any
		wantArtifactID any
	}{
		// 1. When both clientId and serverId are sent, serverId is used.
		{
			clientID:       "clientId",
			serverID:       "serverId",
			control:        &spb.Control{MailboxSlot: "junk"},
			wantClientID:   nil,
			wantArtifactID: "serverId",
		},
		// 2. When only clientId is sent, clientId is used.
		{
			clientID:       "clientId",
			serverID:       "",
			control:        nil,
			wantClientID:   "clientId",
			wantArtifactID: nil,
		},
		// 3. When only serverId is sent, serverId is used.
		{
			clientID:       "",
			serverID:       "serverId",
			control:        nil,
			wantClientID:   nil,
			wantArtifactID: "serverId",
		},
	}

	for i, step := range steps {
		record := &spb.Record{
			RecordType: &spb.Record_Request{
				Request: &spb.Request{
					RequestType: &spb.Request_LinkArtifact{
						LinkArtifact: &spb.LinkArtifactRequest{
							ClientId:         step.clientID,
							ServerId:         step.serverID,
							PortfolioName:    "portfolioName",
							PortfolioEntity:  "portfolioEntity",
							PortfolioProject: "portfolioProject",
						},
					},
				},
			},
			Control: step.control,
		}

		mockGQL.StubMatchOnce(
			gqlmock.WithOpName("LinkArtifact"),
			validLinkArtifactResponse,
		)
		x.Sender.SendRecord(record)
		<-x.Sender.ResponseChan()

		requests := mockGQL.AllRequests()
		assert.Len(t, requests, i+1)
		gqlmock.AssertVariables(t,
			requests[i],
			gqlmock.GQLVar("projectName", gomock.Eq("portfolioProject")),
			gqlmock.GQLVar("entityName", gomock.Eq("portfolioEntity")),
			gqlmock.GQLVar("artifactPortfolioName", gomock.Eq("portfolioName")),
			gqlmock.GQLVar("clientId", gomock.Eq(step.wantClientID)),
			gqlmock.GQLVar("artifactId", gomock.Eq(step.wantArtifactID)))
	}
}

func TestSendUseArtifact(t *testing.T) {
	mockGQL := gqlmock.NewMockClient()
	x := makeSender(t, mockGQL)

	useArtifact := &spb.Record{
		RecordType: &spb.Record_UseArtifact{
			UseArtifact: &spb.UseArtifactRecord{
				Id:      "artifactId",
				Type:    "job",
				Name:    "artifactName",
				Partial: nil,
			},
		},
	}
	// verify doesn't panic if used job artifact
	x.Sender.SendRecord(useArtifact)

	// verify doesn't panic if partial job is broken
	useArtifact = &spb.Record{
		RecordType: &spb.Record_UseArtifact{
			UseArtifact: &spb.UseArtifactRecord{
				Id:   "artifactId",
				Type: "job",
				Name: "artifactName",
				Partial: &spb.PartialJobArtifact{
					JobName: "jobName",
					SourceInfo: &spb.JobSource{
						SourceType: "repo",
						Source: &spb.Source{
							Git: &spb.GitSource{
								GitInfo: &spb.GitInfo{
									Commit: "commit",
									Remote: "remote",
								},
							},
						},
					},
				},
			},
		},
	}
	x.Sender.SendRecord(useArtifact)
}

// validFetchOrgEntityFromEntityResponse is the canned response body stubbed
// for the "FetchOrgEntityFromEntity" GraphQL operation. It describes an entity
// in organization "orgName", whose org entity is named "orgEntityName_123".
//
// It is a const, like validLinkArtifactResponse, because tests never modify it.
const validFetchOrgEntityFromEntityResponse = `{
	"entity": {
		"organization": {
			"name": "orgName",
			"orgEntity": {
				"name": "orgEntityName_123"
			}
		}
	}
}`

// TestLinkRegistryArtifact checks how artifacts.ArtifactLinker resolves the
// entity of a registry project before calling LinkArtifact. It runs against
// servers whose TypeFields response includes the "orgEntity" field ("updated")
// and servers whose response does not ("old").
//
// errorMessage in each case is one of:
//   - "": Link succeeds and LinkArtifact receives "orgEntityName_123";
//   - expectLinkArtifactFailure: Link succeeds against the mock, but the entity
//     passed to LinkArtifact is not the org entity, so a real server would
//     reject the mutation;
//   - any other string: a substring of the error Link is expected to return.
func TestLinkRegistryArtifact(t *testing.T) {
	registryProject := artifacts.RegistryProjectPrefix + "projectName"
	expectLinkArtifactFailure := "expect link artifact to fail, wrong org entity"

	testCases := []struct {
		name              string
		inputOrganization string
		isOldServer       bool
		errorMessage      string
	}{
		{
			"Link registry artifact with orgName updated server",
			"orgName",
			false,
			"",
		},
		{
			"Link registry artifact with orgName old server",
			"orgName",
			true,
			expectLinkArtifactFailure,
		},
		{
			"Link registry artifact with orgEntity name updated server",
			"orgEntityName_123",
			false,
			"",
		},
		{
			"Link registry artifact with orgEntity name old server",
			"orgEntityName_123",
			true,
			"",
		},
		{
			"Link registry artifact with short hand path updated server",
			"",
			false,
			"",
		},
		{
			"Link registry artifact with short hand path old server",
			"",
			true,
			"unsupported",
		},
		{
			"Link with wrong org/orgEntity name with updated server",
			"potato",
			false,
			"update the target path",
		},
		{
			"Link with wrong org/orgEntity name with old server",
			"potato",
			true,
			expectLinkArtifactFailure,
		},
	}
	for _, tc := range testCases {
		// Each subtest gets its own name so failures can be attributed.
		t.Run(tc.name, func(t *testing.T) {
			mockGQL := gqlmock.NewMockClient()

			// On an old server we can't fetch the org entity name, so the
			// linker calls LinkArtifact directly (one request fewer).
			numExpectedRequests := 3
			if tc.isOldServer {
				numExpectedRequests = 2
			}

			req := &spb.LinkArtifactRequest{
				ClientId:              "clientId123",
				PortfolioName:         "portfolioName",
				PortfolioEntity:       "entityName",
				PortfolioProject:      registryProject,
				PortfolioAliases:      nil,
				PortfolioOrganization: tc.inputOrganization,
			}

			var validTypeFieldsResponse string
			if tc.isOldServer {
				validTypeFieldsResponse = `{"TypeInfo": {"fields": []}}`
			} else {
				validTypeFieldsResponse = `{
		"TypeInfo": {
			"fields": [{"name": "orgEntity"}]
		}
	}`
			}
			mockGQL.StubMatchOnce(
				gqlmock.WithOpName("TypeFields"),
				validTypeFieldsResponse,
			)
			mockGQL.StubMatchOnce(
				gqlmock.WithOpName("LinkArtifact"),
				validLinkArtifactResponse,
			)
			mockGQL.StubMatchOnce(
				gqlmock.WithOpName("FetchOrgEntityFromEntity"),
				validFetchOrgEntityFromEntityResponse,
			)

			linker := &artifacts.ArtifactLinker{
				Ctx:           context.Background(),
				LinkArtifact:  req,
				GraphqlClient: mockGQL,
			}
			_, err := linker.Link()
			if err != nil {
				assert.NotEmpty(t, tc.errorMessage, "unexpected error: %v", err)
				assert.ErrorContainsf(t, err, tc.errorMessage,
					"Expected error containing: %s", tc.errorMessage)
				return
			}

			requests := mockGQL.AllRequests()
			if !assert.Len(t, requests, numExpectedRequests) {
				// Stop here rather than panic indexing a short slice below.
				return
			}

			wantEntity := gomock.Eq("orgEntityName_123")
			if tc.errorMessage == expectLinkArtifactFailure {
				// Link() does not fail here: the failure would come from the
				// real LinkArtifact mutation, which the mock always answers
				// successfully. Instead, confirm that the entity name passed
				// to LinkArtifact is wrong, so the real query would fail.
				wantEntity = gomock.Not(wantEntity)
			} else {
				// No error expected: the org entity name must reach LinkArtifact.
				assert.Empty(t, tc.errorMessage)
			}

			// The final request is the LinkArtifact mutation.
			gqlmock.AssertVariables(t,
				requests[numExpectedRequests-1],
				gqlmock.GQLVar("projectName", gomock.Eq(registryProject)),
				gqlmock.GQLVar("entityName", wantEntity),
				gqlmock.GQLVar("artifactPortfolioName", gomock.Eq("portfolioName")),
				gqlmock.GQLVar("clientId", gomock.Eq("clientId123")),
				gqlmock.GQLVar("artifactId", gomock.Nil()))
		})
	}
}
